
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# Columns every input record must provide; load_json() exits if one is missing.
REQUIRED_COLS = ("ground_truth", "mean_prediction")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line arguments of the script.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``input_json`` (str) and ``output`` (str or ``None``).
    """
    parser = argparse.ArgumentParser(description="Compute Spearman correlation on JSON data.")
    parser.add_argument("input_json", type=str, help="Path to JSON/JSONL file containing the data.")
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        # Keep this text in sync with the default computed by the caller,
        # which swaps the input file's suffix for ".spearman.txt".
        help=(
            "Path to write the results (.txt). Defaults to the input path "
            "with its suffix replaced by .spearman.txt"
        ),
    )
    return parser.parse_args(argv)


def load_json(path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load ground truth, mean and median predictions from a JSON or JSONL file.

    The file is first parsed as a single JSON document (a lone object is
    treated as one record); if that fails it is re-read as JSON Lines.

    The median prediction comes from the first available source:
    the ``median_prediction`` column, the per-row median of
    ``all_predictions`` (a list/tuple), or the per-row median of the
    ``prediction`` entries inside the ``agent_predictions`` dict. Rows without
    a usable value get NaN, and the whole array is NaN when none of these
    columns exists.

    Args:
        path: Location of the input file.

    Returns:
        ``(y_true, y_mean, y_median)`` as float arrays of equal length. Scalar
        values that cannot be interpreted as numbers become NaN.

    Raises:
        SystemExit: If the file is missing, cannot be parsed into a table, or
            lacks one of ``REQUIRED_COLS``.
    """
    if not path.exists():
        sys.exit(f"File not found: {path}")

    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        # Not a single JSON document (e.g. several objects, one per line).
        try:
            df = pd.read_json(path, lines=True)
        except ValueError as exc:
            sys.exit(f"Failed to parse JSON/JSONL: {exc}")
    else:
        if isinstance(data, dict):
            data = [data]
        try:
            df = pd.DataFrame(data)
        except (ValueError, TypeError) as exc:
            # Valid JSON that is not record-like (e.g. a bare number).
            sys.exit(f"Failed to parse JSON/JSONL: {exc}")

    for col in REQUIRED_COLS:
        if col not in df.columns:
            sys.exit(f"Required field '{col}' not found in file {path}")

    def _to_float_array(column):
        # Coerce to float so downstream NaN checks never see an object array;
        # unparseable scalars (None, junk strings) turn into NaN.
        return pd.to_numeric(column, errors="coerce").to_numpy(dtype=float)

    def _safe_median(item):
        if isinstance(item, (list, tuple)) and len(item) > 0:
            try:
                return float(np.median(item))
            except (TypeError, ValueError):
                return np.nan
        return np.nan

    def _median_from_agent_dict(agent_dict):
        if not isinstance(agent_dict, dict):
            return np.nan
        preds = []
        for sub in agent_dict.values():
            if not isinstance(sub, dict):
                continue
            try:
                preds.append(float(sub.get("prediction")))
            except (TypeError, ValueError):
                # Missing (None) or non-numeric prediction: skip this agent.
                continue
        return float(np.median(preds)) if preds else np.nan

    y_true = _to_float_array(df["ground_truth"])
    y_mean = _to_float_array(df["mean_prediction"])

    if "median_prediction" in df.columns:
        y_median = _to_float_array(df["median_prediction"])
    elif "all_predictions" in df.columns:
        y_median = df["all_predictions"].apply(_safe_median).to_numpy(dtype=float)
    elif "agent_predictions" in df.columns:
        y_median = df["agent_predictions"].apply(_median_from_agent_dict).to_numpy(dtype=float)
    else:
        y_median = np.full(y_true.shape, np.nan, dtype=float)

    return y_true, y_mean, y_median


def compute_spearman(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, float]:
    """Return the Spearman rank correlation and its p-value as plain floats.

    NaN entries are dropped by SciPy (``nan_policy="omit"``) before ranking.
    """
    result = spearmanr(y_true, y_pred, nan_policy="omit")
    correlation, p_value = result[0], result[1]
    return float(correlation), float(p_value)


def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the Mean Absolute Percentage Error (MAPE), in percent.

    A sample is skipped when its ground truth is zero (the relative error is
    undefined) or when either value is NaN, so that a few missing predictions
    do not turn the whole metric into NaN.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, same length as *y_true*.

    Returns:
        The MAPE over the usable samples, or NaN when there are none.
    """
    # Float views make np.isnan safe even for integer or object inputs.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    mask = (y_true != 0) & ~np.isnan(y_true) & ~np.isnan(y_pred)
    if not np.any(mask):
        return float("nan")

    relative_error = np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])
    return float(np.mean(relative_error) * 100)


def compute_mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the Mean Squared Error (MSE) over samples without missing values.

    A sample is skipped when either its ground truth or its prediction is NaN.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, same length as *y_true*.

    Returns:
        The MSE over the usable samples, or NaN when there are none.
    """
    # Float views make np.isnan safe even for integer or object inputs.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    mask = ~(np.isnan(y_true) | np.isnan(y_pred))
    if not np.any(mask):
        return float("nan")

    return float(np.mean((y_true[mask] - y_pred[mask]) ** 2))


def write_results(
    output_path: Path,
    rho_mean: float,
    p_mean: float,
    rho_median: float | None = None,
    p_median: float | None = None,
    mape_mean: float | None = None,
    mse_mean: float | None = None,
    mape_median: float | None = None,
    mse_median: float | None = None,
    n_samples: int | None = None,
) -> None:
    """Write Spearman, MAPE, and MSE results to *output_path* and echo them to stdout.

    The mean-prediction block is always written. The median block (including
    its MAPE/MSE lines) is written only when *rho_median* is a non-NaN number.
    Every optional metric passed as ``None`` is left out of the report; this
    includes the median p-value, which may be omitted independently of
    *rho_median*.

    Args:
        output_path: Text file to (over)write, encoded as UTF-8.
        rho_mean: Spearman rho for the mean predictions.
        p_mean: p-value belonging to *rho_mean*.
        rho_median: Spearman rho for the median predictions, if available.
        p_median: p-value belonging to *rho_median*, if available.
        mape_mean: MAPE (percent) of the mean predictions, if available.
        mse_mean: MSE of the mean predictions, if available.
        mape_median: MAPE (percent) of the median predictions, if available.
        mse_median: MSE of the median predictions, if available.
        n_samples: Number of samples to report, if available.
    """
    lines = []
    if n_samples is not None:
        lines.append(f"Number of samples: {n_samples}\n")
    lines.append(f"Spearman correlation using MEAN predictions (rho): {rho_mean:.6f}\n")
    lines.append(f"p-value: {p_mean:.6g}\n")
    if mape_mean is not None:
        lines.append(f"MAPE (mean): {mape_mean:.6f}%\n")
    if mse_mean is not None:
        lines.append(f"MSE (mean): {mse_mean:.6f}\n")

    if rho_median is not None and not np.isnan(rho_median):
        # Blank separator line between the mean and median sections.
        lines.append("\n")
        lines.append(f"Spearman correlation using MEDIAN predictions (rho): {rho_median:.6f}\n")
        # p_median defaults to None; formatting it unconditionally would raise.
        if p_median is not None:
            lines.append(f"p-value: {p_median:.6g}\n")
        if mape_median is not None:
            lines.append(f"MAPE (median): {mape_median:.6f}%\n")
        if mse_median is not None:
            lines.append(f"MSE (median): {mse_median:.6f}\n")

    output_path.write_text("".join(lines), encoding="utf-8")
    print(f"Results written to {output_path}")
    print("\n".join(line.strip() for line in lines))


def main() -> None:
    """Command-line entry point: load the data, score it, and write the report.

    Metrics for the mean predictions are always computed. Metrics for the
    median predictions are computed only when at least one median value is
    not NaN; otherwise they are passed on as ``None``.
    """
    cli = parse_args()
    source = Path(cli.input_json)
    if cli.output:
        destination = Path(cli.output)
    else:
        # No explicit output: reuse the input path with its suffix swapped.
        destination = source.with_suffix(".spearman.txt")

    y_true, y_mean, y_median = load_json(source)

    rho_mean, p_mean = compute_spearman(y_true, y_mean)
    mape_mean = compute_mape(y_true, y_mean)
    mse_mean = compute_mse(y_true, y_mean)

    # Median metrics stay None unless some median prediction is present.
    rho_median = p_median = mape_median = mse_median = None
    has_median = not np.all(np.isnan(y_median))
    if has_median:
        rho_median, p_median = compute_spearman(y_true, y_median)
        mape_median = compute_mape(y_true, y_median)
        mse_median = compute_mse(y_true, y_median)

    write_results(
        destination,
        rho_mean,
        p_mean,
        rho_median=rho_median,
        p_median=p_median,
        mape_mean=mape_mean,
        mse_mean=mse_mean,
        mape_median=mape_median,
        mse_median=mse_median,
        n_samples=len(y_true),
    )


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()